Skip to content

[RFC] Add TrainingModule and SGD JNI + PTE-only Training Workflow #12247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

georgehong
Copy link
Contributor

@georgehong georgehong commented Jul 7, 2025

Summary

Adds JNI for SGD and TrainingModule, including a unit test that mirrors train.cpp for a simple XOR example. Also makes the following change:

  • Refactor jni_layer.cpp JTensor <--> Tensor conversion to be a general TensorHybrid utility. This is useful for TrainingModule classes that move maps of Tensors around.
  • Updates android_test_setup.sh to match the pushd-popd directory movement for consistency and flexibility. This is also used to fix errors with generating the XOR files.

Training dependencies are already enabled for Java JNI library, so we skip adding additional guard flags.

Test plan

Updated XOR tests that check .pte only convergence workflow.

sh scripts/build_android_library.sh
sh executorch_android/android_test_setup.sh // Creates xor.ptd, xor.pte, and xor_full.pte files.

./gradlew :executorch_android:connectedAndroidTest // Added unit test to check toy model convergence loss < 0.01

For the XOR tests, the device logs will show convergence values:

I testTrainXOR: Step 0, Loss 0.683540, Input [1, 0], Prediction 1, Label 1
...
I testTrainXOR: Step 4500, Loss 0.000994, Input [0, 0], Prediction 0, Label 0

Copy link

pytorch-bot bot commented Jul 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/12247

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 11 Pending

As of commit b4d7ada with merge base ba19c75 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 7, 2025
Copy link

github-actions bot commented Jul 7, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch from a6e15a7 to c008409 Compare July 8, 2025 07:56
@georgehong georgehong changed the title Update and test JNI Training entrypoints slightly to allow for PTE-only workflows [RFC] Add TrainingModule and SGD JNI + PTE-only Training Workflow Jul 8, 2025
@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch from c008409 to aba87ed Compare July 8, 2025 08:05
@georgehong georgehong requested a review from JacobSzwejbka July 8, 2025 08:05
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77939473.

facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor);
};

class JEValue : public facebook::jni::JavaClass<JEValue> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm does this not already exist for inference?

* @param nesterov Whether to use Nesterov momentum
* @return new {@link org.pytorch.executorch.SGD} object
*/
public static SGD create(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesnt have to be this diff but would it be more "java-y" to have builder classes?

new SGDBuilder().learning_rate().buildSGD();

}

@DoNotStrip
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noob q. What are these "native" apis?

As title, adds wrappers together with unit test based on XOR train.cpp example.
@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch from aba87ed to b4d7ada Compare July 8, 2025 18:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants